#!/bin/bash
#SBATCH --ntasks-per-node=1
#SBATCH --partition=hopper-prod
#SBATCH --requeue
#SBATCH --mail-type=ALL
#SBATCH --mail-user=andres.marafioti@huggingface.co

# ----------------- Set up environment -----------------
# -x: trace every command into the slurm log; -e: abort on first error.
# (pipefail is deliberately not enabled: later pipelines rely on grep
# returning non-zero on no-match without killing the job.)
set -x -e

# Site-specific bootstrap; CONDA_ENV_NAME is expected to be exported by the
# submitting wrapper.
source /fsx/m4/start-m4-user
conda activate base
conda activate "$CONDA_ENV_NAME"

pushd "$WORKING_DIR"

#python -c 'import torch; cuda=torch.version.cuda; assert cuda.startswith("11"), f"cuda-11.x is needed for bf16, got {cuda}"'


# export WANDB_MODE=offline
export WANDB_DIR=/fsx/m4/experiments
export TOKENIZERS_PARALLELISM=false

# GitPython needs an explicit path to the git binary on this cluster.
GIT_PYTHON_GIT_EXECUTABLE=$(command -v git)
export GIT_PYTHON_GIT_EXECUTABLE

export AWS_MAX_ATTEMPTS=20
export PYTHONPATH=$WORKING_DIR:$PYTHONPATH
# ------------------------------------------------------

# ----------------- Define paths -----------------
BASE_S3_PATH="s3://m4-datasets-us-east-1"
BASE_S3_PATH_OCR="s3://pixparse-datasets-us-east-1"

BASE_PATH_DATA_SFT="$BASE_S3_PATH/SFT/mix_11/prepared_docmatix_dataset_full_250K_prepared_mathwriting_dataset_1_0_mix_9_unwrapped_1_0"

# RUN_NAME and TRAINING_CONFIGS_DIR are expected to be exported by the
# submitting wrapper.
SAVE_DIR="/fsx/m4/experiments/local_experiment_dir/$RUN_NAME"
ACCELERATE_CONFIG_FILE="$SAVE_DIR/${SLURM_JOB_ID}_accelerate_config.yaml.autogenerated"
DEEPSPEED_CONFIG_FILE="$SAVE_DIR/${SLURM_JOB_ID}_ds_config.json.autogenerated"
CONFIG_FILE="$TRAINING_CONFIGS_DIR/config.yaml"

# The auto-generated configs and log files below are written into $SAVE_DIR
# and $SAVE_DIR/logs; make sure both exist (mkdir -p is idempotent).
mkdir -p "$SAVE_DIR/logs"
# -------------------------------------------------

# ----------------- Create accelerate config -----------------
# Auto-generate the accelerate config.
# Total worker count across the allocation; both vars come from the
# submitting wrapper / slurm environment.
NUM_GPUS=$((NUM_GPUS_PER_NODE * SLURM_NNODES))
# The first hostname in the allocation acts as the rendezvous master.
MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
# From https://i.hsfzxjy.site/2021-03-10-obtain-a-random-unused-tcp-port-with-bash/
# Print $1 (default 1) random TCP port number(s) in 1025-65535 that are not
# currently in use on this host.
function unused_port() {
    local n=${1:-1}
    # comm -23 keeps candidate ports absent from the in-use list; both inputs
    # are lexicographically sorted, as comm requires.
    # The awk takes the LAST ':'-separated field of the local address, which
    # also works for IPv6 listeners like "[::]:22" (a plain cut -d: -f2 does
    # not).
    comm -23 \
        <(seq 1025 65535 | sort) \
        <(ss -Htan |
            awk '{k = split($4, parts, ":"); print parts[k]}' |
            sort -u) |
        shuf |
        head -n "$n"
}
# Pick a random currently-unused TCP port for the distributed rendezvous.
MASTER_PORT=$(unused_port)


# Write the accelerate config; all $-variables are expanded now, at
# generation time. Redirect target is quoted in case $RUN_NAME contains
# unusual characters.
cat << EOT > "$ACCELERATE_CONFIG_FILE"
# WARNING: do not edit this file as this is an slurm-auto-generated file
compute_environment: LOCAL_MACHINE
deepspeed_config:
  deepspeed_multinode_launcher: standard
  deepspeed_config_file: $DEEPSPEED_CONFIG_FILE
  zero3_init_flag: true
distributed_type: DEEPSPEED
fsdp_config: {}
machine_rank: 0
main_process_ip: $MASTER_ADDR
main_process_port: $MASTER_PORT
main_training_function: main
num_machines: $SLURM_NNODES
num_processes: $NUM_GPUS
use_cpu: false
EOT
# -------------------------------------------------

# ----------------- Create deepspeed config -----------------
# Auto-generate the DS config.
# The delimiter is quoted ('EOT') so the JSON is written verbatim — it
# contains no shell expansions and must not be subject to any.
cat << 'EOT' > "$DEEPSPEED_CONFIG_FILE"
{
    "communication_data_type": "fp32",
    "bf16": {
        "enabled": true
    },
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": false,
        "reduce_bucket_size": 1e8,
        "contiguous_gradients": true,
        "stage3_gather_16bit_weights_on_model_save": false,
        "stage3_prefetch_bucket_size": 1e8,
        "stage3_param_persistence_threshold": "auto",
        "stage3_max_live_parameters": 2e9,
        "stage3_max_reuse_distance": 2e9,
        "offload_optimizer": {
            "device": "none"
        },
        "offload_param": {
            "device": "none"
        }
    },
    "gradient_clipping": "auto",
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "steps_per_print": 2000000
}
EOT
# -------------------------------------------------

# ----------------- Create commands to access data -----------------

# Get SFT code shards
# List every .tar shard key under the SFT prefix. Notes:
# - `cut -c32-` strips the fixed-width date/size columns of `aws s3 ls`
#   output, leaving only the key — assumes the default awscli column layout;
#   TODO confirm against the installed awscli version.
# - `uniq` without a preceding sort is sufficient here because `aws s3 ls`
#   returns keys in sorted order, so duplicates are adjacent.
all_sft_train_shards=$(aws s3 ls --recursive $BASE_PATH_DATA_SFT/ | grep "\.tar"| cut -c32- | uniq)
# Turn each key into a webdataset "pipe:" URL that streams the shard from S3
# via the helper script. The unquoted $all_sft_train_shards is intentional:
# printf re-applies its format string once per whitespace-separated key.
all_sft_train_shards=$(printf -- "pipe:bash ${WORKING_DIR}/experiments/pretraining/vloom/common/webdataset_get_file.sh $BASE_S3_PATH/%s \n" $all_sft_train_shards)


ALL_SFT_TRAIN_SHARDS_TXT_FILE="$SAVE_DIR/all_sft_train_shards.txt"
# NOTE(review): all_sft_train_shards is a scalar string, not an array, so
# "${all_sft_train_shards[@]}" expands to the whole string as one word (bash
# treats a scalar as element 0); the embedded newlines from the printf above
# produce the one-shard-per-line file.
printf "%s\n" "${all_sft_train_shards[@]}" > $ALL_SFT_TRAIN_SHARDS_TXT_FILE

# -------------------------------------------------

# ----------------- Create commands to launch training -----------------
# Snapshot the exact python package versions used by this job, for
# reproducibility and post-mortem debugging.
pip freeze > $SAVE_DIR/${SLURM_JOB_ID}_requirements.txt


# Note: it is important to escape `$SLURM_PROCID` since we want the srun on each node to evaluate this variable
# using python -u to force unbuffered real time logging
# LAUNCHER: the per-node accelerate launch command. The inner double quotes
# around the --rdzv_conf value close and reopen the outer string; that is
# harmless here because the value contains no spaces. \$SLURM_PROCID and
# \$(hostname -s) are escaped so they are evaluated later, on each node,
# rather than now at submission time.
export LAUNCHER="python -u -m accelerate.commands.launch \
    --rdzv_conf "rdzv_backend=c10d,rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT" \
    --config_file $ACCELERATE_CONFIG_FILE \
    --main_process_ip $MASTER_ADDR \
    --main_process_port $MASTER_PORT \
    --machine_rank \$SLURM_PROCID \
    --role \$(hostname -s): --tee 3 \
    "
# PROGRAM: the training entry point and its arguments.
export PROGRAM="m4/training/main.py \
        --config $CONFIG_FILE \
        --sft.training_datasets_paths $ALL_SFT_TRAIN_SHARDS_TXT_FILE \
        --job_id $SLURM_JOB_ID \
        --jz_job_time_sec $JOB_TIME_SEC \
        --save_dir $SAVE_DIR \
    "
# Reset the node-local scratch data directory before training starts.
export M4_DATA_CMD="rm -rf /scratch/m4data && mkdir -p /scratch/m4data && "

# Full per-node command string executed by srun below.
export CMD="$M4_DATA_CMD $LAUNCHER $PROGRAM"


#export NCCL_ALGO=Ring

# TODO: remove this temporary NCCL debugging block once no longer needed
#if [[ $SLURM_PROCID == "0" ]]; then
#  export NCCL_DEBUG=INFO
#  export NCCL_DEBUG_SUBSYS=COLL
#fi


# makes everything very slow
#export CUDA_LAUNCH_BLOCKING=1
        # --laion.training_datasets_paths $ALL_LAION_TRAIN_SHARDS_TXT_FILE \
        # --wiki.training_datasets_paths $ALL_WIKI_TRAIN_SHARDS_TXT_FILE \
# force crashing on nccl issues like hanging broadcast
#export NCCL_ASYNC_ERROR_HANDLING=1

# Log the fully-expanded training command for debugging ($CMD quoted so the
# multi-line formatting is preserved verbatim).
echo "$CMD"

# Snapshot the full job environment for post-mortem debugging; make sure the
# logs directory exists first (mkdir -p is idempotent).
mkdir -p "$SAVE_DIR/logs"
printenv | sort > "$SAVE_DIR/logs/printenv.$SLURM_JOB_ID"

# srun error handling:
# --wait=60: wait 60 sec after the first task terminates before terminating all remaining tasks
# --kill-on-bad-exit=1: terminate a step if any task exits with a non-zero exit code
#    --unbuffered \
# Kept as a whitespace-separated string (not an array) because it is expanded
# unquoted at the srun call site to word-split into separate options.
SRUN_ARGS=" \
    --label \
    --wait=60 \
    --kill-on-bad-exit=1 \
    "
# ---------------------------------------------------------
# ----------------- Launch training -----------------

#PER_RANK_LOGS=$SAVE_DIR/logs/ranks
#mkdir -p $PER_RANK_LOGS
#srun --output=$PER_RANK_LOGS/log.%j-%N.txt  $SRUN_ARGS --jobid $SLURM_JOBID bash -c "$CMD" 2>&1 | tee -a $SAVE_DIR/logs/main_log.txt

# One srun task per node (--ntasks-per-node=1 above); each task runs $CMD,
# i.e. the accelerate launcher, which itself spawns the per-GPU workers.
# $SRUN_ARGS is intentionally unquoted so it word-splits into separate srun
# options. Output is mirrored into the persistent main log via tee.
# NOTE(review): the script's final exit status is tee's, not srun's, so a
# training failure is not propagated to slurm here — confirm this interacts
# as intended with #SBATCH --requeue.
srun $SRUN_ARGS --jobid $SLURM_JOBID bash -c "$CMD" 2>&1 | tee -a $SAVE_DIR/logs/main_log.txt
